
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% GE_save_source_epochs.m
% 
% SA 2018-03-30
% SA 2018-04-08
%
% Converting source time course epochs created by mne_epochs2mat
% into a suitable form to be given as input to the SVM classifier
%
% Modified by: David Sorensen 2021-03-16
%	--Added vector normalization across channels for each timepoint
%	--Deletes the .epoch and desc.mat files after they have been read to conserve disk space
% 

% Start from a clean workspace and close any open figures.
clear all;
close all;

% ------------------------------------------------------------------
% Configuration
% ------------------------------------------------------------------

% Root directory holding one sub-directory per subject.
svm_dir = '/autofs/space/clive_001/users/adriana/GE_SVM';

% Subjects to process. Earlier selections are kept for reference.
%SubjectNames = {'UAG_03'};
%SubjectNames = {'GE_01'};
%SubjectNames = {'GE_05'};
%SubjectNames = {'GE_06'};
%SubjectNames = {'GE_08'};
%SubjectNames = {'GE_20'};
%SubjectNames = {'GE_12','GE_13', 'GE_14', 'GE_15', 'GE_16', 'GE_17', 'GE_18', 'GE_19', 'GE_20', 'GE_21', 'GE_22', 'GE_23', 'GE_24'};
%SubjectNames = {'GE_17', 'GE_18', 'GE_19'};
SubjectNames = {'GE_05', 'GE_06', 'GE_07', 'GE_08', 'GE_09', 'GE_10', 'GE_11', ...
                'GE_12', 'GE_13', 'GE_14', 'GE_15', 'GE_16', 'GE_17', 'GE_18', ...
                'GE_19', 'GE_20', 'GE_21', 'GE_22', 'GE_23', 'GE_24'};

% Optional suffix appended to the subject directory name (e.g. '_test').
%SubjectFolder = '_test';
SubjectFolder = '';
n_subjects = length(SubjectNames);

% Blocks (runs) to combine. Earlier selections are kept for reference.
%blocks = {'B1', 'B2','B3', 'B4','B5', 'B6','B7','B8'};
%blocks = {'B1', 'B2', 'B4','B5', 'B6','B7','B8','B9', 'B10', 'B11','B12', 'B13','B14','B15','B16'};
%blocks = {'B1', 'B3', 'B4','B5', 'B6'};
%blocks = {'B1', 'B2', 'B3', 'B4', 'B6'};
blocks = {'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', ...
          'B9', 'B10', 'B11', 'B12', 'B13', 'B14', 'B15', 'B16'};
n_blocks = length(blocks);

% ROI sets; one output file is written per subject and ROI set.
%roisets = {'SMG','LIFG'};
%roisets = {'pMTG'};
%roisets = {'newSMG', 'newpMTG', 'TempPole', 'FrontalPole', 'CentGyri'};
%roisets = {'L_MTG1', 'L_MTG2', 'R_MTG1', 'R_MTG2'};
roisets = {'L_CC_1-lh', 'L_cMFG_1-lh', 'L_ParaHip_1-lh', 'L_ParsTri_1-lh', 'L_postCG_1-lh', ...
           'R_ParsOrb_1-rh', 'R_postCG_2-rh', 'R_preCG_1-rh', 'R_preCG_2-rh', 'R_SFG_1-rh', ...
           'L_ITG_1-lh', 'L_ITG_2-lh', 'L_MTG_1-lh', 'L_MTG_2-lh', 'L_ParsOrb_1-lh', ...
           'L_SMG_1-lh', 'L_SPC_1-lh', 'L_STG_2-lh', 'L_STG_1-lh', 'L_TPol_1-lh', ...
           'R_AG_1-rh', 'R_ITG_1-rh', 'R_ITG_2-rh', 'R_ITG_3-rh', 'R_LOC_1-rh', ...
           'R_MTG_1-rh', 'R_MTG_2-rh', 'R_MTG_3-rh', 'R_MTG_4-rh', 'R_postCG_3-rh', ...
           'R_postCG_4-rh', 'R_SFG_2-rh', 'R_STG_1-rh', 'L_SFG_1-lh', 'R_postCG_1-rh', ...
           'R_postCG_5-rh', 'R_STG_2-rh', 'R_STG_3-rh', 'R_SFG_3-rh'};
n_roisets = length(roisets);

% Selects which set of trigger codes is collected. The tag is also embedded
% in the output file name.
condition_tag = 'pos'; % 'nvs'|'lvn'|'pos'|'tokens'|'biphone'

% Trigger codes per condition set. Each code is a string because it is
% used both in the input file name and (converted to a number) as the label.
% Earlier selections are kept for reference.
%conditions = {'1', '2', '3','4'};
%conditions = {'5', '6', '7','8'};
%conditions = {'1', '2', '11','22'};
%conditions = {'3', '4', '33','44','1', '2', '11','22','5', '6', '55','66'};
%conditions = {'1', '2', '10','20','3', '4', '30','40','5', '6', '50', '60'};
switch condition_tag
    case 'nvs'
        conditions = {'1', '2', '11', '22', '3', '4', '33', '44', '5', '6', '55', '66'};
    case 'lvn'
        conditions = {'1', '2', '10', '100', '20', '200', '3', '4', '30', '300', ...
                      '40', '400', '5', '6', '50', '60', '500', '600'};
    case 'pos'
        conditions = {'1', '1001', '1003', '2', '2001', '2003', '3', '3001', '3003', ...
                      '4', '4001', '4003', '5', '5001', '5003', '6', '6001', '6003'};
    case 'tokens'
        % tokens used for when data will be split randomly: half, third,
        % nvl (lexical train on hubs) and svn (train on hubs, test on all neighbors)
        % NOTE: '105' was previously written as 105' (missing opening quote),
        % which MATLAB parses as the number 105 transposed, not a string.
        conditions = {'101', '211', '221', '231', '311', '321', '331', ...
                      '102', '212', '222', '232', '312', '322', '332', ...
                      '103', '213', '223', '233', '313', '323', '333', ...
                      '104', '214', '224', '234', '314', '324', '334', ...
                      '105', '215', '225', '235', '315', '325', '335', ...
                      '106', '216', '226', '236', '316', '326', '336'};
    case 'biphone'
        conditions = {'1', '2', '3', '4', '5', '6', ...
                      '1201', '2201', '3201', '4201', '5201', '6201', ...
                      '1202', '2202', '3202', '4202', '5202', '6202'};
        %{'1', '2', '3', '4', '5', '6', '1201', '2201', '3201', '4201', '5201', '6201', '1202', '2202', '3202', '4202', '5202', '6202', '1301', '2301', '3301', '4301', '5301', '6301', '1302', '2302', '3302', '4302', '5302', '6302'};
    otherwise
        % Fail here with a clear message instead of an "undefined variable"
        % error further down.
        error('GE_save_source_epochs:badConditionTag', ...
              'Unknown condition_tag ''%s''.', condition_tag);
end

% Guard against a stray unquoted code slipping into the list again.
if ~iscellstr(conditions)
    error('GE_save_source_epochs:badConditions', ...
          'Every entry of conditions must be a character vector (condition_tag ''%s'').', ...
          condition_tag);
end

n_cond = length(conditions);
no_outputs = {};   % messages for subject/ROI pairs that produced no output file
%n_cond = 3;

% ------------------------------------------------------------------
% Main loop: for every subject and ROI set, gather the epochs of all
% blocks and conditions, baseline-correct and normalize them, and save
% them into one <subject>_<roiset>_<condition_tag>_all.mat file.
%
% WARNING: each *_desc.mat file and the binary epoch file it points to
% are DELETED once they have been read successfully (to save disk space).
% ------------------------------------------------------------------
for i_subject = 1:n_subjects % each subject separately, for now..
    subject = char(SubjectNames(i_subject));

    % separate *_all.mat file for each ROI:
    for i_roiset = 1:n_roisets
      roiset = char(roisets(i_roiset))   % no semicolon: echoed as progress output

      % Per-ROI accumulators. They are reset for EVERY ROI so that a missing
      % first block/condition file can no longer leave arrays from the
      % previous ROI or subject in place and leak them into this output.
      i_all_epochs = 0;      % number of epochs stored so far for this ROI
      data = [];             % n_samples x n_chs x n_all_epochs
      condition_index = [];  % 1 x n_all_epochs numeric trigger code per epoch
      times = [];            % 1 x n_samples time axis

      % combine data from all blocks (i.e., runs) and all conditions into
      % a single (subject- and ROI-specific) *_all.mat file:
      for i_block = 1:n_blocks
          block = char(blocks(i_block));
          for i_cond = 1:n_cond
              cond_name = char(conditions(i_cond));
              fid = -1;   % lets the catch block tell whether a file is still open
              try
                  infile = sprintf('%s/%s%s/epochs/%s_%s_%s_%s_desc.mat', ...
                      svm_dir, subject, SubjectFolder, subject, block, roiset, cond_name);

                  % Load into a struct so the file cannot overwrite script variables.
                  % A missing file raises MATLAB:load:couldNotReadFile (handled below).
                  desc = load(infile);
                  MNE_epoch_info = desc.MNE_epoch_info;

                  % Column 4 scaled by 0.001 -- presumably ms -> s; TODO confirm
                  % against the mne_epochs2mat description format.
                  t_min = single(MNE_epoch_info.epochs(1,4))*0.001;
                  n_samples = double(MNE_epoch_info.epochs(1,5));
                  sfreq = MNE_epoch_info.sfreq;
                  n_chs = double(MNE_epoch_info.nchan);
                  if n_chs < 32
                      warning('Subject %s %s only contains %d vertices', subject, roiset, n_chs)
                  end
                  n_epochs = double(MNE_epoch_info.nepoch);

                  % Size the accumulator from the first file that loads, and
                  % insist that later files have the same layout.
                  if i_all_epochs == 0
                      data = zeros(n_samples, n_chs, 0);
                  elseif size(data,1) ~= n_samples || size(data,2) ~= n_chs
                      error('GE_save_source_epochs:sizeMismatch', ...
                          '%s has %d samples x %d channels, expected %d x %d', ...
                          infile, n_samples, n_chs, size(data,1), size(data,2));
                  end
                  % Reserve room for all epochs of this file in one step.
                  data = cat(3, data, zeros(n_samples, n_chs, n_epochs));

                  % Time axis and baseline window are the same for every epoch
                  % of this file, so compute them once.
                  times = t_min + single(0:n_samples-1) / sfreq;
                  b_min = 1;                         % first time point index
                  b_max = round(1 - t_min * sfreq);  % time point index for zero-time
                  if b_max < b_min
                      warning('GE_save_source_epochs:noBaseline', ...
                          'No pre-stimulus samples in %s; baseline-corrected data will be NaN', infile);
                  end

                  % str2double avoids the eval() that str2num performs.
                  cond_code = str2double(cond_name);

                  fid = fopen(MNE_epoch_info.epoch_file, 'r');
                  if fid < 0
                      error('GE_save_source_epochs:openFailed', ...
                          'Could not open epoch file %s', MNE_epoch_info.epoch_file);
                  end

                  for i_epoch = 1:n_epochs % read one epoch at a time
                      % The file stores, for each time point, all channels as
                      % big-endian 32-bit floats. Reading n_chs x n_samples
                      % column-wise and transposing gives samples x channels.
                      [raw, n_read] = fread(fid, [n_chs, n_samples], 'single', 'b');
                      if n_read ~= n_chs * n_samples
                          error('GE_save_source_epochs:shortRead', ...
                              'Epoch %d of %s is truncated (%d of %d values read)', ...
                              i_epoch, MNE_epoch_info.epoch_file, n_read, n_chs * n_samples);
                      end
                      this_data = raw.';

                      % Baseline correction using pre-stimulus average (per channel)
                      this_data = this_data - mean(this_data(b_min:b_max, :), 1);

                      % Vector normalization across channels for each time point
                      % (each row is divided by its 2-norm)
                      this_data = normalize(this_data, 2, 'norm');

                      i_all_epochs = i_all_epochs + 1;
                      condition_index(i_all_epochs) = cond_code;
                      data(:,:,i_all_epochs) = this_data;
                  end % i_epoch
                  fclose(fid);
                  fid = -1;

                  % Inputs are no longer needed: remove them to conserve disk space.
                  delete(MNE_epoch_info.epoch_file)
                  delete(infile)

              catch ME
                  % Do not leak the file handle if reading failed part-way.
                  if fid >= 3
                      fclose(fid);
                  end
                  if strcmp(ME.identifier, 'MATLAB:load:couldNotReadFile')
                      % No description file. This is only acceptable when the
                      % block's raw ROI file exists (i.e. there were simply no
                      % events of this type). Use the block NAME, not the loop
                      % index, so non-contiguous block lists are handled correctly.
                      raw_inverse_fname = sprintf('%s/%s%s/roiraw/%s_%s_%s_raw.fif', ...
                          svm_dir, subject, SubjectFolder, subject, block, roiset);
                      if ~isfile(raw_inverse_fname)
                          error('GE_save_source_epochs:noInverse', ...
                              'No inverse timecourse for %s %s %s', subject, block, roiset);
                      else
                          warning('GE_save_source_epochs:missingEvents', ...
                              'Missing events of type %s for %s %s %s', ...
                              cond_name, subject, block, roiset);
                      end
                  else
                      rethrow(ME);
                  end
              end % try

          end % i_cond
      end % i_block

      n_all_epochs = i_all_epochs;

      if n_all_epochs > 0
          %epoch_data = data(:,1:n_chs-1,:); % use this if STI channel needs to be removed
          epoch_data = data;

          epochfile = sprintf('%s/%s%s/epochs/%s_%s_%s_all.mat', svm_dir, subject, SubjectFolder, subject, roiset, condition_tag)
          save(epochfile, 'epoch_data', 'condition_index','times');
      else
          no_outputs{end+1} = sprintf('No data for %s %s. No output file written', subject, roiset);
      end

%       % Optional debugging plots (disabled):
%       for i_ch = 1:n_chs  %adriana commented out for testing
%           %thisfig = 10*i_roiset + i_ch; %
%           thisfig = 10*i_roiset;
%           for i_epoch = 1:n_all_epochs
%               figure(thisfig); % all source waveforms
%               plot(times, epoch_data(:,i_ch,i_epoch));
%               hold on;
%           end % i_epoch
%       end % i_ch
%       hold off;

%      for i_ch = 1:n_chs
%          %roifig = 100*i_roiset;
%          roifig = 1000
%          figure(roifig); % averaged source waveforms
%          %plot(times, mean(data(:,i_ch,:),3));
%          plot(times, mean(epoch_data(:,i_ch,:),3));
%          hold on;
%          i_ch
%          epoch_data(1,i_ch,1)
%      end % i_ch

    end % i_roiset

%    for i_epoch = 1:n_all_epochs
%      figure(1)
%      plot(times, data(:,1,i_epoch));
%      hold on;
%    end % i_epoch
%    figure(2)
%    plot(times, mean(data(:,1,:),3));

end % i_subject

% Summary: list every subject/ROI pair for which no output file was written.
for k = 1:length(no_outputs)
    disp(no_outputs{k});
end

% figure(roifig); % averaged source waveforms
% %legend(roisets);
% %legend(repmat(roisets,1,n_subjects));
% ii = 0;
% for i_roiset = 1:n_roisets
%     for i_ch = 1:n_chs
%         ii = ii + 1;
%         roiset_legend{ii} = char(roisets(i_roiset));
%     end % i_ch
% end % i_roiset
% 
% legend(roiset_legend);


